{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "# Anna KaRNNa\n",
    "\n",
    "In this notebook, I'll build a character-wise RNN trained on Anna Karenina, one of my all-time favorite books. It'll be able to generate new text based on the text from the book.\n",
    "\n",
    "This network is based on Andrej Karpathy's [post on RNNs](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) and [implementation in Torch](https://github.com/karpathy/char-rnn). I also drew on some information [here at r2rt](http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html) and from [Sherjil Ozair](https://github.com/sherjilozair/char-rnn-tensorflow) on GitHub. Below is the general architecture of the character-wise RNN.\n",
    "\n",
    "<img src=\"assets/charseq.jpeg\" width=\"500\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "import time\n",
    "from collections import namedtuple\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "First we'll load the text file and convert it into integers for our network to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
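    "with open('anna.txt', 'r') as f:\n",
    "    text = f.read()\n",
    "# The vocabulary is the set of unique characters in the text\n",
    "vocab = set(text)\n",
    "# Lookup tables: character -> integer and integer -> character\n",
    "vocab_to_int = {c: i for i, c in enumerate(vocab)}\n",
    "int_to_vocab = dict(enumerate(vocab))\n",
    "# Encode the whole text as an array of integers\n",
    "chars = np.array([vocab_to_int[c] for c in text], dtype=np.int32)"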
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Chapter 1\\n\\n\\nHappy families are all alike; every unhappy family is unhappy in its own\\nway.\\n\\nEverythin'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text[:100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([18,  3,  5, 13, 39,  6, 58, 24, 59, 80, 80, 80, 31,  5, 13, 13, 19,\n",
       "       24, 14,  5, 40, 82, 49, 82,  6,  2, 24,  5, 58,  6, 24,  5, 49, 49,\n",
       "       24,  5, 49, 82, 43,  6, 50, 24,  6, 29,  6, 58, 19, 24, 69, 10,  3,\n",
       "        5, 13, 13, 19, 24, 14,  5, 40, 82, 49, 19, 24, 82,  2, 24, 69, 10,\n",
       "        3,  5, 13, 13, 19, 24, 82, 10, 24, 82, 39,  2, 24, 45, 51, 10, 80,\n",
       "       51,  5, 19, 17, 80, 80, 67, 29,  6, 58, 19, 39,  3, 82, 10], dtype=int32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chars[:100]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "Now I need to split up the data into batches, and into training and validation sets. I should be making a test set here, but I'm not going to worry about that. My test will be if the network can generate new text.\n",
    "\n",
    "Here I'll make both input and target arrays. The targets are the same as the inputs, except shifted one character over. I'll also drop the last bit of data so that I'll only have completely full batches.\n",
    "\n",
    "The idea here is to make a 2D matrix where the number of rows is equal to the batch size, that is, the number of sequences in each batch. Each row will be one long concatenated string from the character data. We'll split this data into a training set and validation set using the `split_frac` keyword. With the default `split_frac=0.9`, this keeps the first 90% of the batches in the training set and the remaining 10% in the validation set."
   ]
  },
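  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "A minimal sketch of the reshaping described above, using a toy array in place of the encoded book. The names `toy_chars`, `toy_batch_size`, and `toy_num_steps` are assumptions made for illustration only; they are not used elsewhere in the notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy stand-in for the encoded text (hypothetical data, not the book)\n",
    "toy_chars = np.arange(25)\n",
    "toy_batch_size, toy_num_steps = 2, 3\n",
    "\n",
    "toy_slice_size = toy_batch_size * toy_num_steps   # characters consumed per batch\n",
    "toy_n_batches = len(toy_chars) // toy_slice_size  # number of full batches\n",
    "\n",
    "# Inputs, and targets shifted one character ahead\n",
    "toy_x = toy_chars[:toy_n_batches * toy_slice_size]\n",
    "toy_y = toy_chars[1:toy_n_batches * toy_slice_size + 1]\n",
    "\n",
    "# One row per sequence in the batch\n",
    "toy_x = np.stack(np.split(toy_x, toy_batch_size))\n",
    "toy_y = np.stack(np.split(toy_y, toy_batch_size))\n",
    "\n",
    "print(toy_x.shape)   # (2, 12)\n",
    "print(toy_x[:, :3])\n",
    "print(toy_y[:, :3])\n",
    "```\n",
    "\n",
    "Each row of `toy_x` is a contiguous run of characters, and every entry of `toy_y` is the character that follows the matching entry of `toy_x`."
   ]
  },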
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def split_data(chars, batch_size, num_steps, split_frac=0.9):\n",
    "    \"\"\" \n",
    "    Split character data into training and validation sets, inputs and targets for each set.\n",
    "    \n",
    "    Arguments\n",
    "    ---------\n",
    "    chars: character array\n",
    "    batch_size: Number of sequences in each batch\n",
    "    num_steps: Number of sequence steps to keep in the input and pass to the network\n",
    "    split_frac: Fraction of batches to keep in the training set\n",
    "    \n",
    "    \n",
    "    Returns train_x, train_y, val_x, val_y\n",
    "    \"\"\"\n",
    "    \n",
    "    \n",
    "    slice_size = batch_size * num_steps\n",
    "    n_batches = len(chars) // slice_size\n",
    "    \n",
    "    # Drop the last few characters to make only full batches\n",
    "    x = chars[: n_batches*slice_size]\n",
    "    y = chars[1: n_batches*slice_size + 1]\n",
    "    \n",
    "    # Split the data into batch_size slices, then stack them into a 2D matrix \n",
    "    x = np.stack(np.split(x, batch_size))\n",
    "    y = np.stack(np.split(y, batch_size))\n",
    "    \n",
    "    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps\n",
    "    \n",
    "    # Split into training and validation sets, keep the first split_frac batches for training\n",
    "    split_idx = int(n_batches*split_frac)\n",
    "    train_x, train_y = x[:, :split_idx*num_steps], y[:, :split_idx*num_steps]\n",
    "    val_x, val_y = x[:, split_idx*num_steps:], y[:, split_idx*num_steps:]\n",
    "    \n",
    "    return train_x, train_y, val_x, val_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "train_x, train_y, val_x, val_y = split_data(chars, 10, 200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 178400)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[18,  3,  5, 13, 39,  6, 58, 24, 59, 80],\n",
       "       [57, 10, 23, 24,  3,  6, 24, 40, 45, 29],\n",
       "       [24, 16,  5, 39, 16,  3, 82, 10, 73, 24],\n",
       "       [45, 39,  3,  6, 58, 24, 51, 45, 69, 49],\n",
       "       [24, 39,  3,  6, 24, 49,  5, 10, 23, 54],\n",
       "       [24, 62,  3, 58, 45, 69, 73,  3, 24, 49],\n",
       "       [39, 24, 39, 45, 80, 23, 45, 17, 80, 80],\n",
       "       [45, 24,  3,  6, 58,  2,  6, 49, 14, 64],\n",
       "       [ 3,  5, 39, 24, 82,  2, 24, 39,  3,  6],\n",
       "       [ 6, 58,  2,  6, 49, 14, 24,  5, 10, 23]], dtype=int32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_x[:,:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "I'll write another function to grab batches out of the arrays made by `split_data`. Here each batch will be a sliding window on these arrays with size `batch_size x num_steps`. For example, if we want our network to train on a sequence of 100 characters, `num_steps = 100`. For the next batch, we'll shift this window to the next sequence of `num_steps` characters. In this way we can feed batches to the network and the cell states will carry over from one batch to the next."
   ]
  },
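  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "To make the sliding window concrete, here is a small sketch on a toy array. The helper `toy_batches` and the array `toy` are hypothetical names used only for this illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_batches(arr, num_steps):\n",
    "    # Yield consecutive windows of num_steps columns (hypothetical helper)\n",
    "    n_batches = arr.shape[1] // num_steps\n",
    "    for b in range(n_batches):\n",
    "        yield arr[:, b * num_steps:(b + 1) * num_steps]\n",
    "\n",
    "toy = np.arange(24).reshape(2, 12)   # 2 sequences of 12 characters each\n",
    "windows = list(toy_batches(toy, num_steps=4))\n",
    "\n",
    "print(len(windows))   # 3\n",
    "print(windows[0])\n",
    "print(windows[1])\n",
    "```\n",
    "\n",
    "Row 0 of the second window starts exactly where row 0 of the first window ended, which is why the LSTM state from one batch is a sensible starting state for the next."
   ]
  },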
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def get_batch(arrs, num_steps):\n",
    "    \"\"\"Yield successive windows of num_steps columns from each array in arrs.\"\"\"\n",
    "    batch_size, slice_size = arrs[0].shape\n",
    "    \n",
    "    n_batches = slice_size // num_steps\n",
    "    for b in range(n_batches):\n",
    "        yield [x[:, b*num_steps: (b+1)*num_steps] for x in arrs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def build_rnn(num_classes, batch_size=50, num_steps=50, lstm_size=128, num_layers=2,\n",
    "              learning_rate=0.001, grad_clip=5, sampling=False):\n",
    "        \n",
    "    # When sampling, we feed the network one character at a time\n",
    "    if sampling:\n",
    "        batch_size, num_steps = 1, 1\n",
    "\n",
    "    tf.reset_default_graph()\n",
    "    \n",
    "    # Declare placeholders we'll feed into the graph\n",
    "    \n",
    "    inputs = tf.placeholder(tf.int32, [batch_size, num_steps], name='inputs')\n",
    "    x_one_hot = tf.one_hot(inputs, num_classes, name='x_one_hot')\n",
    "\n",
    "\n",
    "    targets = tf.placeholder(tf.int32, [batch_size, num_steps], name='targets')\n",
    "    y_one_hot = tf.one_hot(targets, num_classes, name='y_one_hot')\n",
    "    y_reshaped = tf.reshape(y_one_hot, [-1, num_classes])\n",
    "    \n",
    "    keep_prob = tf.placeholder(tf.float32, name='keep_prob')\n",
    "    \n",
    "    # Build the RNN layers\n",
    "    \n",
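    "    # Build a separate cell object for each layer; reusing one object ([drop] * num_layers)\n",
    "    # is rejected by some TensorFlow 1.x releases\n",
    "    def build_cell():\n",
    "        lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)\n",
    "        return tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)\n",
    "    \n",
    "    cell = tf.contrib.rnn.MultiRNNCell([build_cell() for _ in range(num_layers)])\n",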
    "\n",
    "    initial_state = cell.zero_state(batch_size, tf.float32)\n",
    "\n",
    "    # Run the data through the RNN layers\n",
    "    outputs, state = tf.nn.dynamic_rnn(cell, x_one_hot, initial_state=initial_state)\n",
    "    final_state = state\n",
    "    \n",
    "    # Reshape output so it's a bunch of rows, one row for each cell output\n",
    "    \n",
    "    seq_output = tf.concat(outputs, axis=1, name='seq_output')\n",
    "    output = tf.reshape(seq_output, [-1, lstm_size], name='graph_output')\n",
    "    \n",
    "    # Now connect the RNN outputs to a softmax layer and calculate the cost\n",
    "    softmax_w = tf.Variable(tf.truncated_normal((lstm_size, num_classes), stddev=0.1),\n",
    "                           name='softmax_w')\n",
    "    softmax_b = tf.Variable(tf.zeros(num_classes), name='softmax_b')\n",
    "    logits = tf.matmul(output, softmax_w) + softmax_b\n",
    "\n",
    "    preds = tf.nn.softmax(logits, name='predictions')\n",
    "    \n",
    "    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped, name='loss')\n",
    "    cost = tf.reduce_mean(loss, name='cost')\n",
    "\n",
    "    # Optimizer for training, using gradient clipping to control exploding gradients\n",
    "    tvars = tf.trainable_variables()\n",
    "    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), grad_clip)\n",
    "    train_op = tf.train.AdamOptimizer(learning_rate)\n",
    "    optimizer = train_op.apply_gradients(zip(grads, tvars))\n",
    "\n",
    "    # Export the nodes \n",
    "    export_nodes = ['inputs', 'targets', 'initial_state', 'final_state',\n",
    "                    'keep_prob', 'cost', 'preds', 'optimizer']\n",
    "    Graph = namedtuple('Graph', export_nodes)\n",
    "    local_dict = locals()\n",
    "    graph = Graph(*[local_dict[each] for each in export_nodes])\n",
    "    \n",
    "    return graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Hyperparameters\n",
    "\n",
    "Here I'm defining the hyperparameters for the network. The two you probably haven't seen before are `lstm_size` and `num_layers`. These set the number of hidden units in the LSTM layers and the number of LSTM layers, respectively. Making these bigger gives the network more capacity, which usually improves its performance, but you'll have to watch out for overfitting. If your validation loss is much larger than the training loss, you're probably overfitting. In that case, decrease the size of the network or decrease the dropout keep probability."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "batch_size = 100\n",
    "num_steps = 100\n",
    "lstm_size = 512\n",
    "num_layers = 2\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Write out the graph for TensorBoard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "model = build_rnn(len(vocab),\n",
    "                  batch_size=batch_size,\n",
    "                  num_steps=num_steps,\n",
    "                  learning_rate=learning_rate,\n",
    "                  lstm_size=lstm_size,\n",
    "                  num_layers=num_layers)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    \n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    file_writer = tf.summary.FileWriter('./logs/1', sess.graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Training\n",
    "\n",
    "Time for training, which is pretty straightforward. Here I pass in some data and get an LSTM state back. Then I pass that state back into the network so the next batch can continue the state from the previous batch. Every so often (set by `save_every_n`) I calculate the validation loss and save a checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "!mkdir -p checkpoints/anna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true,
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Expected state to be a tuple of length 2, but received: Tensor(\"initial_state:0\", shape=(2, 2, 100, 512), dtype=float32)",
     "output_type": "error",
     "traceback": [
      "---------------------------------------------------------------------------",
      "ValueError                                Traceback (most recent call last)",
      "<ipython-input-20-4190d11347ea> in <module>()\n      8                   learning_rate=learning_rate,\n      9                   lstm_size=lstm_size,\n---> 10                   num_layers=num_layers)\n     11 \n     12 saver = tf.train.Saver(max_to_keep=100)\n",
      "<ipython-input-19-a7e675cc0f3d> in build_rnn(num_classes, batch_size, num_steps, lstm_size, num_layers, learning_rate, grad_clip, sampling)\n     25     # Run the data through the RNN layers\n     26     rnn_inputs = [tf.squeeze(i, squeeze_dims=[1]) for i in tf.split(x_one_hot, num_steps, 1)]\n---> 27     outputs, state = tf.contrib.rnn.static_rnn(cell, rnn_inputs, initial_state=initial_state)\n     28     final_state = tf.identity(state, name='final_state')\n     29 \n",
      "/home/mat/miniconda3/envs/tf-gpu/lib/python3.5/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn.py in static_rnn(cell, inputs, initial_state, dtype, sequence_length, scope)\n    195             state_size=cell.state_size)\n    196       else:\n--> 197         (output, state) = call_cell()\n    198 \n    199       outputs.append(output)\n",
      "/home/mat/miniconda3/envs/tf-gpu/lib/python3.5/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn.py in <lambda>()\n    182       if time > 0: varscope.reuse_variables()\n    183       # pylint: disable=cell-var-from-loop\n--> 184       call_cell = lambda: cell(input_, state)\n    185       # pylint: enable=cell-var-from-loop\n    186       if sequence_length is not None:\n",
      "/home/mat/miniconda3/envs/tf-gpu/lib/python3.5/site-packages/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py in __call__(self, inputs, state, scope)\n    647               raise ValueError(\n    648                   \"Expected state to be a tuple of length %d, but received: %s\"\n--> 649                   % (len(self.state_size), state))\n    650             cur_state = state[i]\n    651           else:\n",
      "ValueError: Expected state to be a tuple of length 2, but received: Tensor(\"initial_state:0\", shape=(2, 2, 100, 512), dtype=float32)"
     ]
    }
   ],
   "source": [
    "epochs = 1\n",
    "save_every_n = 200\n",
    "train_x, train_y, val_x, val_y = split_data(chars, batch_size, num_steps)\n",
    "\n",
    "model = build_rnn(len(vocab), \n",
    "                  batch_size=batch_size,\n",
    "                  num_steps=num_steps,\n",
    "                  learning_rate=learning_rate,\n",
    "                  lstm_size=lstm_size,\n",
    "                  num_layers=num_layers)\n",
    "\n",
    "saver = tf.train.Saver(max_to_keep=100)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "    \n",
    "    # Use the line below to load a checkpoint and resume training\n",
    "    #saver.restore(sess, 'checkpoints/anna20.ckpt')\n",
    "    \n",
    "    n_batches = int(train_x.shape[1]/num_steps)\n",
    "    iterations = n_batches * epochs\n",
    "    for e in range(epochs):\n",
    "        \n",
    "        # Train network\n",
    "        new_state = sess.run(model.initial_state)\n",
    "        loss = 0\n",
    "        for b, (x, y) in enumerate(get_batch([train_x, train_y], num_steps), 1):\n",
    "            iteration = e*n_batches + b\n",
    "            start = time.time()\n",
    "            feed = {model.inputs: x,\n",
    "                    model.targets: y,\n",
    "                    model.keep_prob: 0.5,\n",
    "                    model.initial_state: new_state}\n",
    "            batch_loss, new_state, _ = sess.run([model.cost, model.final_state, model.optimizer], \n",
    "                                                 feed_dict=feed)\n",
    "            loss += batch_loss\n",
    "            end = time.time()\n",
    "            print('Epoch {}/{} '.format(e+1, epochs),\n",
    "                  'Iteration {}/{}'.format(iteration, iterations),\n",
    "                  'Training loss: {:.4f}'.format(loss/b),\n",
    "                  '{:.4f} sec/batch'.format((end-start)))\n",
    "        \n",
    "            \n",
    "            if (iteration%save_every_n == 0) or (iteration == iterations):\n",
    "                # Check performance on the validation set; keep_prob is 1 so dropout is off\n",
    "                val_loss = []\n",
    "                # Use a separate state so validation doesn't overwrite the training state\n",
    "                val_state = sess.run(model.initial_state)\n",
    "                for val_batch_x, val_batch_y in get_batch([val_x, val_y], num_steps):\n",
    "                    feed = {model.inputs: val_batch_x,\n",
    "                            model.targets: val_batch_y,\n",
    "                            model.keep_prob: 1.,\n",
    "                            model.initial_state: val_state}\n",
    "                    batch_loss, val_state = sess.run([model.cost, model.final_state], feed_dict=feed)\n",
    "                    val_loss.append(batch_loss)\n",
    "\n",
    "                print('Validation loss:', np.mean(val_loss),\n",
    "                      'Saving checkpoint!')\n",
    "                saver.save(sess, \"checkpoints/anna/i{}_l{}_{:.3f}.ckpt\".format(iteration, lstm_size, np.mean(val_loss)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "model_checkpoint_path: \"checkpoints/anna/i3560_l512_1.122.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i200_l512_2.432.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i400_l512_1.980.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i600_l512_1.750.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i800_l512_1.595.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i1000_l512_1.484.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i1200_l512_1.407.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i1400_l512_1.349.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i1600_l512_1.292.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i1800_l512_1.255.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i2000_l512_1.224.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i2200_l512_1.204.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i2400_l512_1.187.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i2600_l512_1.172.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i2800_l512_1.160.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i3000_l512_1.148.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i3200_l512_1.137.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i3400_l512_1.129.ckpt\"\n",
       "all_model_checkpoint_paths: \"checkpoints/anna/i3560_l512_1.122.ckpt\""
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.train.get_checkpoint_state('checkpoints/anna')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "## Sampling\n",
    "\n",
    "Now that the network is trained, we can use it to generate new text. The idea is that we pass in a character, then the network will predict the next character. We can then feed that new character back in to predict the one after it, and keep doing this to generate all new text. I also included some functionality to prime the network with some text by passing in a string and building up a state from that.\n",
    "\n",
    "The network gives us predictions for each character. To reduce noise and make things a little less random, I'm going to only choose a new character from the top N most likely characters.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
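    "def pick_top_n(preds, vocab_size, top_n=5):\n",
    "    p = np.squeeze(preds)\n",
    "    # Zero out everything except the top_n most likely characters\n",
    "    p[np.argsort(p)[:-top_n]] = 0\n",
    "    # Renormalize so the remaining probabilities sum to 1\n",
    "    p = p / np.sum(p)\n",
    "    c = np.random.choice(vocab_size, 1, p=p)[0]\n",
    "    return c"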
   ]
  },
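  {
   "cell_type": "markdown",
   "metadata": {
    "deletable": true,
    "editable": true
   },
   "source": [
    "As a quick illustration of the top-N idea, here is a sketch on a made-up five-character distribution. The helper `toy_top_n` and the values in `toy_probs` are assumptions for this example only; unlike `pick_top_n`, the helper copies its input and returns the filtered distribution instead of a sampled index.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def toy_top_n(probs, top_n=2):\n",
    "    # Copy so the caller's data is left untouched (hypothetical helper)\n",
    "    p = np.array(probs, dtype=np.float64)\n",
    "    # Zero all but the top_n largest probabilities, then renormalize\n",
    "    p[np.argsort(p)[:-top_n]] = 0\n",
    "    return p / p.sum()\n",
    "\n",
    "toy_probs = [0.05, 0.4, 0.1, 0.3, 0.15]\n",
    "filtered = toy_top_n(toy_probs, top_n=2)\n",
    "print(filtered)\n",
    "\n",
    "np.random.seed(0)\n",
    "draws = np.random.choice(len(toy_probs), size=1000, p=filtered)\n",
    "print(sorted(set(draws.tolist())))   # only indices 1 and 3 can appear\n",
    "```\n",
    "\n",
    "A smaller `top_n` makes the generated text more conservative, while a larger one lets less likely characters through."
   ]
  },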
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": true,
    "deletable": true,
    "editable": true
   },
   "outputs": [],
   "source": [
    "def sample(checkpoint, n_samples, lstm_size, vocab_size, prime=\"The \"):\n",
    "    samples = [c for c in prime]\n",
    "    model = build_rnn(vocab_size, lstm_size=lstm_size, sampling=True)\n",
    "    saver = tf.train.Saver()\n",
    "    with tf.Session() as sess:\n",
    "        saver.restore(sess, checkpoint)\n",
    "        new_state = sess.run(model.initial_state)\n",
    "        for c in prime:\n",
    "            x = np.zeros((1, 1))\n",
    "            x[0,0] = vocab_to_int[c]\n",
    "            feed = {model.inputs: x,\n",
    "                    model.keep_prob: 1.,\n",
    "                    model.initial_state: new_state}\n",
    "            preds, new_state = sess.run([model.preds, model.final_state], \n",
    "                                         feed_dict=feed)\n",
    "\n",
    "        c = pick_top_n(preds, vocab_size)\n",
    "        samples.append(int_to_vocab[c])\n",
    "\n",
    "        for i in range(n_samples):\n",
    "            x[0,0] = c\n",
    "            feed = {model.inputs: x,\n",
    "                    model.keep_prob: 1.,\n",
    "                    model.initial_state: new_state}\n",
    "            preds, new_state = sess.run([model.preds, model.final_state], \n",
    "                                         feed_dict=feed)\n",
    "\n",
    "            c = pick_top_n(preds, vocab_size)\n",
    "            samples.append(int_to_vocab[c])\n",
    "        \n",
    "    return ''.join(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Farlathit that if had so\n",
      "like it that it were. He could not trouble to his wife, and there was\n",
      "anything in them of the side of his weaky in the creature at his forteren\n",
      "to him.\n",
      "\n",
      "\"What is it? I can't bread to those,\" said Stepan Arkadyevitch. \"It's not\n",
      "my children, and there is an almost this arm, true it mays already,\n",
      "and tell you what I have say to you, and was not looking at the peasant,\n",
      "why is, I don't know him out, and she doesn't speak to me immediately, as\n",
      "you would say the countess and the more frest an angelembre, and time and\n",
      "things's silent, but I was not in my stand that is in my head. But if he\n",
      "say, and was so feeling with his soul. A child--in his soul of his\n",
      "soul of his soul. He should not see that any of that sense of. Here he\n",
      "had not been so composed and to speak for as in a whole picture, but\n",
      "all the setting and her excellent and society, who had been delighted\n",
      "and see to anywing had been being troed to thousand words on them,\n",
      "we liked him.\n",
      "\n",
      "That set in her money at the table, he came into the party. The capable\n",
      "of his she could not be as an old composure.\n",
      "\n",
      "\"That's all something there will be down becime by throe is\n",
      "such a silent, as in a countess, I should state it out and divorct.\n",
      "The discussion is not for me. I was that something was simply they are\n",
      "all three manshess of a sensitions of mind it all.\"\n",
      "\n",
      "\"No,\" he thought, shouted and lifting his soul. \"While it might see your\n",
      "honser and she, I could burst. And I had been a midelity. And I had a\n",
      "marnief are through the countess,\" he said, looking at him, a chosing\n",
      "which they had been carried out and still solied, and there was a sen that\n",
      "was to be completely, and that this matter of all the seconds of it, and\n",
      "a concipation were to her husband, who came up and conscaously, that he\n",
      "was not the station. All his fourse she was always at the country,,\n",
      "to speak oft, and though they were to hear the delightful throom and\n",
      "whether they came towards the morning, and his living and a coller and\n",
      "hold--the children. \n"
     ]
    }
   ],
   "source": [
    "checkpoint = \"checkpoints/anna/i3560_l512_1.122.ckpt\"\n",
    "samp = sample(checkpoint, 2000, lstm_size, len(vocab), prime=\"Far\")\n",
    "print(samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Farnt him oste wha sorind thans tout thint asd an sesand an hires on thime sind thit aled, ban thand and out hore as the ter hos ton ho te that, was tis tart al the hand sostint him sore an tit an son thes, win he se ther san ther hher tas tarereng,.\n",
      "\n",
      "Anl at an ades in ond hesiln, ad hhe torers teans, wast tar arering tho this sos alten sorer has hhas an siton ther him he had sin he ard ate te anling the sosin her ans and\n",
      "arins asd and ther ale te tot an tand tanginge wath and ho ald, so sot th asend sat hare sother horesinnd, he hesense wing ante her so tith tir sherinn, anded and to the toul anderin he sorit he torsith she se atere an ting ot hand and thit hhe so the te wile har\n",
      "ens ont in the sersise, and we he seres tar aterer, to ato tat or has he he wan ton here won and sen heren he sosering, to to theer oo adent har herere the wosh oute, was serild ward tous hed astend..\n",
      "\n",
      "I's sint on alt in har tor tit her asd hade shithans ored he talereng an soredendere tim tot hees. Tise sor and \n"
     ]
    }
   ],
   "source": [
    "checkpoint = \"checkpoints/anna/i200_l512_2.432.ckpt\"\n",
    "samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime=\"Far\")\n",
    "print(samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fard as astice her said he celatice of to seress in the raice, and to be the some and sere allats to that said to that the sark and a cast a the wither ald the pacinesse of her had astition, he said to the sount as she west at hissele. Af the cond it he was a fact onthis astisarianing.\n",
      "\n",
      "\n",
      "\"Or a ton to to be that's a more at aspestale as the sont of anstiring as\n",
      "thours and trey.\n",
      "\n",
      "The same wo dangring the\n",
      "raterst, who sore and somethy had ast out an of his book. \"We had's beane were that, and a morted a thay he had to tere. Then to\n",
      "her homent andertersed his his ancouted to the pirsted, the soution for of the pirsice inthirgest and stenciol, with the hard and and\n",
      "a colrice of to be oneres,\n",
      "the song to this anderssad.\n",
      "The could ounterss the said to serom of\n",
      "soment a carsed of sheres of she\n",
      "torded\n",
      "har and want in their of hould, but\n",
      "her told in that in he tad a the same to her. Serghing an her has and with the seed, and the camt ont his about of the\n",
      "sail, the her then all houg ant or to hus to \n"
     ]
    }
   ],
   "source": [
    "checkpoint = \"checkpoints/anna/i600_l512_1.750.ckpt\"\n",
    "samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime=\"Far\")\n",
    "print(samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": false,
    "deletable": true,
    "editable": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Farrat, his felt has at it.\n",
      "\n",
      "\"When the pose ther hor exceed\n",
      "to his sheant was,\" weat a sime of his sounsed. The coment and the facily that which had began terede a marilicaly whice whether the pose of his hand, at she was alligated herself the same on she had to\n",
      "taiking to his forthing and streath how to hand\n",
      "began in a lang at some at it, this he cholded not set all her. \"Wo love that is setthing. Him anstering as seen that.\"\n",
      "\n",
      "\"Yes in the man that say the mare a crances is it?\" said Sergazy Ivancatching. \"You doon think were somether is ifficult of a mone of\n",
      "though the most at the countes that the\n",
      "mean on the come to say the most, to\n",
      "his feesing of\n",
      "a man she, whilo he\n",
      "sained and well, that he would still at to said. He wind at his for the sore in the most\n",
      "of hoss and almoved to see him. They have betine the sumper into at he his stire, and what he was that at the so steate of the\n",
      "sound, and shin should have a geest of shall feet on the conderation to she had been at that imporsing the dre\n"
     ]
    }
   ],
   "source": [
    "checkpoint = \"checkpoints/anna/i1000_l512_1.484.ckpt\"\n",
    "samp = sample(checkpoint, 1000, lstm_size, len(vocab), prime=\"Far\")\n",
    "print(samp)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
